import hashlib
import json
import secrets
from datetime import datetime
from datetime import datetime as dt
from datetime import timezone
from typing import Any, List, Literal, Optional, cast

from pydantic import BaseModel

import litellm
from litellm._logging import verbose_proxy_logger
from litellm.constants import MAX_STRING_LENGTH_PROMPT_IN_DB, REDACTED_BY_LITELM_STRING
from litellm.litellm_core_utils.core_helpers import get_litellm_metadata_from_kwargs
from litellm.litellm_core_utils.safe_json_dumps import safe_dumps
from litellm.proxy._types import SpendLogsMetadata, SpendLogsPayload
from litellm.proxy.utils import PrismaClient, hash_token
from litellm.types.utils import (
    StandardLoggingGuardrailInformation,
    StandardLoggingMCPToolCall,
    StandardLoggingModelInformation,
    StandardLoggingPayload,
    StandardLoggingVectorStoreRequest,
    VectorStoreSearchResponse,
)
from litellm.utils import get_end_user_id_for_cost_tracking


def _is_master_key(api_key: str, _master_key: Optional[str]) -> bool:
    if _master_key is None:
        return False

    ## string comparison
    is_master_key = secrets.compare_digest(api_key, _master_key)
    if is_master_key:
        return True

    ## hash comparison
    is_master_key = secrets.compare_digest(api_key, hash_token(_master_key))
    if is_master_key:
        return True

    return False


def _get_spend_logs_metadata(
    metadata: Optional[dict],
    applied_guardrails: Optional[List[str]] = None,
    batch_models: Optional[List[str]] = None,
    mcp_tool_call_metadata: Optional[StandardLoggingMCPToolCall] = None,
    vector_store_request_metadata: Optional[
        List[StandardLoggingVectorStoreRequest]
    ] = None,
    guardrail_information: Optional[List[StandardLoggingGuardrailInformation]] = None,
    usage_object: Optional[dict] = None,
    model_map_information: Optional[StandardLoggingModelInformation] = None,
    cold_storage_object_key: Optional[str] = None,
) -> SpendLogsMetadata:
    """
    Build the SpendLogsMetadata object stored with a spend log row.

    When `metadata` is None, returns an all-empty metadata object (status
    defaults to "success", cold_storage_object_key is still honored).
    Otherwise, copies only the keys declared on SpendLogsMetadata out of
    `metadata`, then overwrites the explicitly-passed fields (guardrails,
    batch models, MCP / vector-store / usage / model-map info, cold storage
    key) on top of it.
    """
    if metadata is None:
        return SpendLogsMetadata(
            user_api_key=None,
            user_api_key_alias=None,
            user_api_key_team_id=None,
            user_api_key_org_id=None,
            user_api_key_user_id=None,
            user_api_key_team_alias=None,
            spend_logs_metadata=None,
            requester_ip_address=None,
            additional_usage_values=None,
            applied_guardrails=None,
            status="success",  # fix: was `None or "success"`, which always evaluates to "success"
            error_information=None,
            proxy_server_request=None,
            batch_models=None,
            mcp_tool_call_metadata=None,
            vector_store_request_metadata=None,
            model_map_information=None,
            usage_object=None,
            guardrail_information=None,
            cold_storage_object_key=cold_storage_object_key,
        )
    verbose_proxy_logger.debug(
        "getting payload for SpendLogs, available keys in metadata: "
        + str(list(metadata.keys()))
    )

    # Filter the metadata dictionary to include only the keys declared on
    # the SpendLogsMetadata TypedDict.
    clean_metadata = SpendLogsMetadata(
        **{  # type: ignore
            key: metadata[key]
            for key in SpendLogsMetadata.__annotations__.keys()
            if key in metadata
        }
    )
    # Explicit arguments take precedence over whatever was in `metadata`.
    clean_metadata["applied_guardrails"] = applied_guardrails
    clean_metadata["batch_models"] = batch_models
    clean_metadata["mcp_tool_call_metadata"] = mcp_tool_call_metadata
    clean_metadata["vector_store_request_metadata"] = (
        _get_vector_store_request_for_spend_logs_payload(vector_store_request_metadata)
    )
    clean_metadata["guardrail_information"] = guardrail_information
    clean_metadata["usage_object"] = usage_object
    clean_metadata["model_map_information"] = model_map_information
    clean_metadata["cold_storage_object_key"] = cold_storage_object_key

    return clean_metadata


def generate_hash_from_response(response_obj: Any) -> str:
    """
    Generate a stable hash from a response object.

    Serializes the object to JSON with sorted keys so the hash is stable
    across key ordering; falls back to `str(...)` when the object is not
    JSON-serializable. MD5 is used only as a non-security fingerprint.

    Args:
        response_obj: The response object to hash (can be dict, list, etc.)

    Returns:
        A hex string representation of the MD5 hash
    """
    try:
        serialized = json.dumps(response_obj, sort_keys=True)
    except Exception:
        # Object not JSON-serializable — fall back to its string form.
        serialized = str(response_obj)
    return hashlib.md5(serialized.encode()).hexdigest()


def get_spend_logs_id(
    call_type: str, response_obj: dict, kwargs: dict
) -> Optional[str]:
    """
    Pick the request id to store in the spend log.

    Batch/file calls get a content-derived hash (their responses lack a
    usable id); everything else uses the response "id", falling back to
    kwargs["litellm_call_id"].
    """
    if call_type in ("aretrieve_batch", "acreate_file"):
        # Generate a stable hash from the response object.
        return generate_hash_from_response(response_obj)

    response_id = cast(Optional[str], response_obj.get("id"))
    if response_id:
        return response_id
    return cast(Optional[str], kwargs.get("litellm_call_id"))


def _extract_usage_for_ocr_call(response_obj: Any, response_obj_dict: dict) -> dict:
    """
    Extract usage information for OCR/AOCR calls.

    OCR responses use usage_info (with pages_processed) instead of token-based usage.

    Args:
        response_obj: The raw response object (can be dict, BaseModel, or other)
        response_obj_dict: Dictionary representation of the response object

    Returns:
        A dict with prompt_tokens=0, completion_tokens=0, total_tokens=0,
        and pages_processed from usage_info.
    """
    usage_info = None

    # Try to extract usage_info from dict
    if isinstance(response_obj_dict, dict) and "usage_info" in response_obj_dict:
        usage_info = response_obj_dict.get("usage_info")

    # Try to extract usage_info from object attributes if not found in dict
    if not usage_info and hasattr(response_obj, "usage_info"):
        usage_info = response_obj.usage_info
        if hasattr(usage_info, "model_dump"):
            usage_info = usage_info.model_dump()
        elif hasattr(usage_info, "__dict__"):
            usage_info = vars(usage_info)

    # For OCR, we track pages instead of tokens
    if usage_info is not None:
        # Handle dict or object with attributes
        if isinstance(usage_info, dict):
            result = {
                "prompt_tokens": 0,  # OCR doesn't use traditional tokens
                "completion_tokens": 0,
                "total_tokens": 0,
            }
            # Add all fields from usage_info, including pages_processed
            for key, value in usage_info.items():
                result[key] = value
            # Ensure pages_processed exists
            if "pages_processed" not in result:
                result["pages_processed"] = 0
            return result
        else:
            return {
                "prompt_tokens": 0,
                "completion_tokens": 0,
                "total_tokens": 0,
                "pages_processed": 0,
            }
    else:
        return {}


def get_logging_payload(  # noqa: PLR0915
    kwargs, response_obj, start_time, end_time
) -> SpendLogsPayload:
    """
    Build the SpendLogsPayload row for a single LLM call.

    Standardizes the payload used across the spend-logging sinks
    (s3, dynamoDB, langfuse, the SpendLogs DB table).

    Args:
        kwargs: litellm call kwargs (litellm_params, metadata, call_type,
            cache_hit, response_cost, standard_logging_object, ...).
        response_obj: the raw response — dict, pydantic BaseModel, or any
            other object (wrapped as {"result": str(...)}).
        start_time: request start timestamp (normalized to UTC).
        end_time: request end timestamp (normalized to UTC).

    Returns:
        A fully-populated SpendLogsPayload.

    Raises:
        Re-raises any exception hit while constructing the payload,
        after logging it.
    """
    from litellm.proxy.proxy_server import general_settings, master_key

    if kwargs is None:
        kwargs = {}

    if response_obj is None:
        response_obj = {}
    elif not isinstance(response_obj, BaseModel) and not isinstance(response_obj, dict):
        # Non-dict, non-pydantic responses are stringified into a wrapper dict.
        response_obj = {"result": str(response_obj)}
    # standardize this function to be used across, s3, dynamoDB, langfuse logging
    litellm_params = kwargs.get("litellm_params", {})
    metadata = get_litellm_metadata_from_kwargs(kwargs)
    completion_start_time = kwargs.get("completion_start_time", end_time)
    call_type = kwargs.get("call_type")
    cache_hit = kwargs.get("cache_hit", False)

    # Convert response_obj to dict first
    if isinstance(response_obj, dict):
        response_obj_dict = response_obj
    elif isinstance(response_obj, BaseModel):
        response_obj_dict = response_obj.model_dump()
    else:
        response_obj_dict = {}

    # Handle OCR responses which use usage_info instead of usage
    if call_type in ["ocr", "aocr"]:
        usage = _extract_usage_for_ocr_call(response_obj, response_obj_dict)
    else:
        # Use response_obj_dict instead of response_obj to avoid calling .get() on Pydantic models
        usage = response_obj_dict.get("usage", None) or {}
        if isinstance(usage, litellm.Usage):
            usage = dict(usage)

    id = get_spend_logs_id(call_type or "acompletion", response_obj_dict, kwargs)
    standard_logging_payload = cast(
        Optional[StandardLoggingPayload], kwargs.get("standard_logging_object", None)
    )

    end_user_id = get_end_user_id_for_cost_tracking(litellm_params)

    api_key = metadata.get("user_api_key", "")

    # Token counts from the standard logging payload are used as fallbacks
    # when the usage dict does not carry the corresponding key.
    standard_logging_prompt_tokens: int = 0
    standard_logging_completion_tokens: int = 0
    standard_logging_total_tokens: int = 0
    if standard_logging_payload is not None:
        standard_logging_prompt_tokens = standard_logging_payload.get(
            "prompt_tokens", 0
        )
        standard_logging_completion_tokens = standard_logging_payload.get(
            "completion_tokens", 0
        )
        standard_logging_total_tokens = standard_logging_payload.get("total_tokens", 0)
    if api_key is not None and isinstance(api_key, str):
        if api_key.startswith("sk-"):
            # hash the api_key
            api_key = hash_token(api_key)
        if (
            _is_master_key(api_key=api_key, _master_key=master_key)
            and general_settings.get("disable_adding_master_key_hash_to_db") is True
        ):
            api_key = "litellm_proxy_master_key"  # use a known alias, if the user disabled storing master key in db

    if (
        standard_logging_payload is not None
    ):  # [TODO] migrate completely to sl payload. currently missing pass-through endpoint data
        api_key = (
            api_key
            or standard_logging_payload["metadata"].get("user_api_key_hash")
            or ""
        )
        end_user_id = end_user_id or standard_logging_payload["metadata"].get(
            "user_api_key_end_user_id"
        )
    # BUG FIX: Don't overwrite api_key when standard_logging_payload is None
    # The api_key was already extracted from metadata (line 243) and hashed (lines 256-259)
    request_tags = (
        json.dumps(metadata.get("tags", []))
        if isinstance(metadata.get("tags", []), list)
        else "[]"
    )
    if (
        standard_logging_payload is not None
        and standard_logging_payload.get("request_tags") is not None
    ):  # use 'tags' from standard logging payload instead
        request_tags = json.dumps(standard_logging_payload["request_tags"])
    # NOTE(review): this master-key aliasing repeats the check done above;
    # it re-applies after api_key may have been filled from the standard
    # logging payload — confirm whether both checks are still needed.
    if (
        _is_master_key(api_key=api_key, _master_key=master_key)
        and general_settings.get("disable_adding_master_key_hash_to_db") is True
    ):
        api_key = "litellm_proxy_master_key"  # use a known alias, if the user disabled storing master key in db

    _model_id = metadata.get("model_info", {}).get("id", "")
    _model_group = metadata.get("model_group", "")

    # clean up litellm metadata
    clean_metadata = _get_spend_logs_metadata(
        metadata,
        applied_guardrails=(
            standard_logging_payload["metadata"].get("applied_guardrails", None)
            if standard_logging_payload is not None
            else None
        ),
        batch_models=(
            standard_logging_payload.get("hidden_params", {}).get("batch_models", None)
            if standard_logging_payload is not None
            else None
        ),
        mcp_tool_call_metadata=(
            standard_logging_payload["metadata"].get("mcp_tool_call_metadata", None)
            if standard_logging_payload is not None
            else None
        ),
        vector_store_request_metadata=(
            standard_logging_payload["metadata"].get(
                "vector_store_request_metadata", None
            )
            if standard_logging_payload is not None
            else None
        ),
        usage_object=(
            standard_logging_payload["metadata"].get("usage_object", None)
            if standard_logging_payload is not None
            else None
        ),
        model_map_information=(
            standard_logging_payload["model_map_information"]
            if standard_logging_payload is not None
            else None
        ),
        guardrail_information=(
            standard_logging_payload.get("guardrail_information", None)
            if standard_logging_payload is not None
            else None
        ),
        cold_storage_object_key=(
            standard_logging_payload["metadata"].get("cold_storage_object_key", None)
            if standard_logging_payload is not None
            else None
        ),
    )

    # Everything in `usage` beyond the three token counters is preserved in
    # metadata under additional_usage_values (pydantic values are dumped).
    special_usage_fields = ["completion_tokens", "prompt_tokens", "total_tokens"]
    additional_usage_values = {}
    for k, v in usage.items():
        if k not in special_usage_fields:
            if isinstance(v, BaseModel):
                v = v.model_dump()
            additional_usage_values.update({k: v})
    clean_metadata["additional_usage_values"] = additional_usage_values

    if litellm.cache is not None:
        cache_key = litellm.cache.get_cache_key(**kwargs)
    else:
        cache_key = "Cache OFF"
    if cache_hit is True:
        import time

        id = f"{id}_cache_hit{time.time()}"  # SpendLogs does not allow duplicate request_id

    mcp_namespaced_tool_name = None
    mcp_tool_call_metadata: Optional[StandardLoggingMCPToolCall] = clean_metadata.get(
        "mcp_tool_call_metadata"
    )
    if mcp_tool_call_metadata is not None:
        mcp_namespaced_tool_name = mcp_tool_call_metadata.get(
            "namespaced_tool_name", None
        )

    try:
        payload: SpendLogsPayload = SpendLogsPayload(
            request_id=str(id),
            call_type=call_type or "",
            api_key=str(api_key),
            cache_hit=str(cache_hit),
            startTime=_ensure_datetime_utc(start_time),
            endTime=_ensure_datetime_utc(end_time),
            completionStartTime=_ensure_datetime_utc(completion_start_time),
            model=kwargs.get("model", "") or "",
            user=metadata.get("user_api_key_user_id", "") or "",
            team_id=metadata.get("user_api_key_team_id", "") or "",
            organization_id=metadata.get("user_api_key_org_id") or "",
            metadata=safe_dumps(clean_metadata),
            cache_key=cache_key,
            spend=kwargs.get("response_cost", 0),
            total_tokens=usage.get("total_tokens", standard_logging_total_tokens),
            prompt_tokens=usage.get("prompt_tokens", standard_logging_prompt_tokens),
            completion_tokens=usage.get(
                "completion_tokens", standard_logging_completion_tokens
            ),
            request_tags=request_tags,
            end_user=end_user_id or "",
            api_base=litellm_params.get("api_base", ""),
            model_group=_model_group,
            model_id=_model_id,
            mcp_namespaced_tool_name=mcp_namespaced_tool_name,
            requester_ip_address=clean_metadata.get("requester_ip_address", None),
            custom_llm_provider=kwargs.get("custom_llm_provider", ""),
            messages=_get_messages_for_spend_logs_payload(
                standard_logging_payload=standard_logging_payload, metadata=metadata
            ),
            response=_get_response_for_spend_logs_payload(standard_logging_payload),
            proxy_server_request=_get_proxy_server_request_for_spend_logs_payload(
                metadata=metadata, litellm_params=litellm_params
            ),
            session_id=_get_session_id_for_spend_log(
                kwargs=kwargs,
                standard_logging_payload=standard_logging_payload,
            ),
            status=_get_status_for_spend_log(
                metadata=metadata,
            ),
        )

        verbose_proxy_logger.debug(
            "SpendTable: created payload - payload: %s\n\n",
            json.dumps(payload, indent=4, default=str),
        )

        return payload
    except Exception as e:
        verbose_proxy_logger.exception(
            "Error creating spendlogs object - {}".format(str(e))
        )
        raise e


def _get_session_id_for_spend_log(
    kwargs: dict,
    standard_logging_payload: Optional[StandardLoggingPayload],
) -> str:
    """
    Get the session id for the spend log.

    Resolution order:
    1. `trace_id` from the standard logging payload, if set.
    2. `litellm_trace_id` from kwargs (users can set this per request).
    3. A freshly generated uuid4, so every spend log has a session id.
    """
    if (
        standard_logging_payload is not None
        and standard_logging_payload.get("trace_id") is not None
    ):
        return str(standard_logging_payload.get("trace_id"))

    # Users can dynamically set the trace_id for each request by passing `litellm_trace_id` in kwargs
    if kwargs.get("litellm_trace_id") is not None:
        return str(kwargs.get("litellm_trace_id"))

    # Import lazily: only needed when we have to mint a fresh id
    # (the original imported unconditionally on every call).
    from litellm._uuid import uuid

    # Ensure we always have a session id, if none is provided
    return str(uuid.uuid4())


def _ensure_datetime_utc(timestamp: datetime) -> datetime:
    """Helper to ensure datetime is in UTC"""
    timestamp = timestamp.astimezone(timezone.utc)
    return timestamp


async def get_spend_by_team_and_customer(
    start_date: dt,
    end_date: dt,
    team_id: str,
    customer_id: str,
    prisma_client: PrismaClient,
):
    """
    Query daily spend for one team + end-customer pair over a date range.

    Runs a raw SQL aggregation over LiteLLM_SpendLogs (left-joined with
    LiteLLM_TeamTable for the team alias), grouping spend and token totals
    per day, per (model, api_key), and rolling those up into a JSON array
    per day.

    Args:
        start_date: inclusive lower bound on SpendLogs.startTime.
        end_date: inclusive upper bound (the query adds 1 day and uses `<`,
            so the whole end date is included).
        team_id: SpendLogs.team_id to filter on ($3).
        customer_id: SpendLogs.end_user to filter on ($4).
        prisma_client: client used to execute the raw query.

    Returns:
        The raw DB rows (one per day, with a `teams_customers` JSON array),
        or [] when the query returns None.
    """
    sql_query = """
    WITH SpendByModelApiKey AS (
        SELECT
            date_trunc('day', sl."startTime") AS group_by_day,
            COALESCE(tt.team_alias, 'Unassigned Team') AS team_name,
            sl.end_user AS customer,
            sl.model,
            sl.api_key,
            SUM(sl.spend) AS model_api_spend,
            SUM(sl.total_tokens) AS model_api_tokens
        FROM 
            "LiteLLM_SpendLogs" sl
        LEFT JOIN 
            "LiteLLM_TeamTable" tt 
        ON 
            sl.team_id = tt.team_id
        WHERE
            sl."startTime" >= $1::timestamptz AND sl."startTime" < ($2::timestamptz + INTERVAL '1 day')
            AND sl.team_id = $3
            AND sl.end_user = $4
        GROUP BY
            date_trunc('day', sl."startTime"),
            tt.team_alias,
            sl.end_user,
            sl.model,
            sl.api_key
    )
        SELECT
            group_by_day,
            jsonb_agg(jsonb_build_object(
                'team_name', team_name,
                'customer', customer,
                'total_spend', total_spend,
                'metadata', metadata
            )) AS teams_customers
        FROM (
            SELECT
                group_by_day,
                team_name,
                customer,
                SUM(model_api_spend) AS total_spend,
                jsonb_agg(jsonb_build_object(
                    'model', model,
                    'api_key', api_key,
                    'spend', model_api_spend,
                    'total_tokens', model_api_tokens
                )) AS metadata
            FROM 
                SpendByModelApiKey
            GROUP BY
                group_by_day,
                team_name,
                customer
        ) AS aggregated
        GROUP BY
            group_by_day
        ORDER BY
            group_by_day;
    """

    db_response = await prisma_client.db.query_raw(
        sql_query, start_date, end_date, team_id, customer_id
    )
    if db_response is None:
        return []

    return db_response


def _get_messages_for_spend_logs_payload(
    standard_logging_payload: Optional[StandardLoggingPayload],
    metadata: Optional[dict] = None,
) -> str:
    """
    Value for the `messages` column of the spend log row.

    Currently always "{}" — request messages are not persisted in this
    column; both parameters are accepted for interface stability but unused.
    """
    return "{}"


def _sanitize_request_body_for_spend_logs_payload(
    request_body: dict,
    visited: Optional[set] = None,
) -> dict:
    """
    Recursively sanitize request body to prevent logging large base64 strings or other large values.

    Strings longer than MAX_STRING_LENGTH_PROMPT_IN_DB are truncated, keeping
    ~35% from the beginning and ~65% from the end with a marker in between
    (the end of a conversation usually carries more context). Nested dicts
    and lists are walked recursively; a `visited` set of object ids breaks
    reference cycles (an already-seen dict sanitizes to {}).

    Args:
        request_body: The request body to sanitize.
        visited: Internal set of id()s of dicts already being processed.

    Returns:
        A new dict mirroring `request_body` with long strings truncated.
    """
    from litellm.constants import LITELLM_TRUNCATED_PAYLOAD_FIELD

    if visited is None:
        visited = set()

    # Track objects by identity to break cycles / shared subtrees.
    obj_id = id(request_body)
    if obj_id in visited:
        return {}
    visited.add(obj_id)

    def _sanitize_value(value: Any) -> Any:
        if isinstance(value, dict):
            return _sanitize_request_body_for_spend_logs_payload(value, visited)
        elif isinstance(value, list):
            return [_sanitize_value(item) for item in value]
        elif isinstance(value, str):
            if len(value) > MAX_STRING_LENGTH_PROMPT_IN_DB:
                # Keep 35% from beginning and 65% from end (end is usually more important).
                start_chars = int(MAX_STRING_LENGTH_PROMPT_IN_DB * 0.35)
                end_chars = int(MAX_STRING_LENGTH_PROMPT_IN_DB * 0.65)

                # Guard against rounding pushing us over the total limit.
                if start_chars + end_chars > MAX_STRING_LENGTH_PROMPT_IN_DB:
                    end_chars = MAX_STRING_LENGTH_PROMPT_IN_DB - start_chars

                # fix: compute total_keep AFTER the guard so skipped_chars is
                # accurate; also removed an unreachable `len(value) <= MAX`
                # check that sat inside this `len(value) > MAX` branch.
                total_keep = start_chars + end_chars
                skipped_chars = len(value) - total_keep

                # Build the truncated string: beginning + truncation marker + end
                return (
                    f"{value[:start_chars]}"
                    f"... ({LITELLM_TRUNCATED_PAYLOAD_FIELD} skipped {skipped_chars} chars) ..."
                    f"{value[-end_chars:]}"
                )
            return value
        return value

    return {k: _sanitize_value(v) for k, v in request_body.items()}


def _get_proxy_server_request_for_spend_logs_payload(
    metadata: dict,
    litellm_params: dict,
) -> str:
    """
    Serialize the incoming proxy request body for spend-log storage.

    Returns "{}" unless prompt/response storage is enabled
    (_should_store_prompts_and_responses_in_spend_logs()); when enabled,
    returns the sanitized request body as a JSON string.
    """
    if not _should_store_prompts_and_responses_in_spend_logs():
        return "{}"

    proxy_request = cast(
        Optional[dict], litellm_params.get("proxy_server_request", {})
    )
    if proxy_request is None:
        return "{}"

    body = proxy_request.get("body", {}) or {}
    sanitized_body = _sanitize_request_body_for_spend_logs_payload(body)
    return json.dumps(sanitized_body, default=str)


def _get_vector_store_request_for_spend_logs_payload(
    vector_store_request_metadata: Optional[List[StandardLoggingVectorStoreRequest]],
) -> Optional[List[StandardLoggingVectorStoreRequest]]:
    """
    Prepare vector store request metadata for spend-log storage.

    When prompt/response storage is enabled the metadata is returned
    unchanged. Otherwise the text of every retrieved content item is
    redacted in place with REDACTED_BY_LITELM_STRING.
    """
    if _should_store_prompts_and_responses_in_spend_logs():
        return vector_store_request_metadata

    if vector_store_request_metadata is None:
        return None

    # Storage of prompts/responses is disabled: redact document text in place.
    for request in vector_store_request_metadata:
        search_response: VectorStoreSearchResponse = (
            request.get("vector_store_search_response")
            or VectorStoreSearchResponse()
        )
        for response_item in search_response.get("data", []) or []:
            for content_item in response_item.get("content", []) or []:
                if "text" in content_item:
                    content_item["text"] = REDACTED_BY_LITELM_STRING
    return vector_store_request_metadata


def _get_response_for_spend_logs_payload(
    payload: Optional[StandardLoggingPayload],
) -> str:
    """
    Serialize the LLM response for spend-log storage.

    Returns "{}" when storage is disabled, the payload is missing, or the
    response is None. Long string values are truncated via the shared
    request-body sanitizer before serialization.
    """
    if payload is None:
        return "{}"
    if not _should_store_prompts_and_responses_in_spend_logs():
        return "{}"

    response_obj: Any = payload.get("response")
    if response_obj is None:
        return "{}"

    # Reuse the dict sanitizer by wrapping the response in a one-key dict.
    wrapper = _sanitize_request_body_for_spend_logs_payload(
        {"response": response_obj}
    )
    sanitized = wrapper.get("response", response_obj)

    if sanitized is None:
        return "{}"
    if isinstance(sanitized, str):
        return sanitized
    return safe_dumps(sanitized)


def _should_store_prompts_and_responses_in_spend_logs() -> bool:
    """
    Whether full prompts/responses may be written to spend logs.

    True when `store_prompts_in_spend_logs` is enabled in general_settings,
    or the STORE_PROMPTS_IN_SPEND_LOGS secret/env flag resolves to True.
    """
    from litellm.proxy.proxy_server import general_settings
    from litellm.secret_managers.main import get_secret_bool

    if general_settings.get("store_prompts_in_spend_logs") is True:
        return True
    return get_secret_bool("STORE_PROMPTS_IN_SPEND_LOGS") is True


def _get_status_for_spend_log(
    metadata: dict,
) -> Literal["success", "failure"]:
    """
    Get the status for the spend log.

    It's only a failure if metadata.get("status") is "failure"
    """
    _status: Optional[str] = metadata.get("status", None)
    if _status == "failure":
        return "failure"
    return "success"
